import torch
import torch.nn as nn
import torch.nn.functional as F
class MultiHeadAttention(nn.Module):
    def __init__(self, embed_size=512, num_heads=8):
        super(MultiHeadAttention, self).__init__()
        self.embed_size = embed_size
        self.num_heads = num_heads
        self.head_dim = embed_size // num_heads

        assert self.head_dim * num_heads == embed_size, "Embed size must be divisible by num_heads"

        self.query = nn.Linear(embed_size, embed_size)
        self.key = nn.Linear(embed_size, embed_size)
        self.value = nn.Linear(embed_size, embed_size)
        self.fc_out = nn.Linear(embed_size, embed_size)
    def forward(self, x, mask=None):
        """
        Args:
            x: input tensor of shape (batch_size, seq_len, embed_size)
            mask: boolean mask where True marks positions to be masked out.
                Two mask types are supported:
                - Padding mask, broadcast to (batch_size, 1, 1, seq_len)
                - Sequence (causal) mask, broadcast to (batch_size, 1, seq_len, seq_len)
        """
        batch_size, seq_len, _ = x.shape

        # Project the input, then split the embedding into num_heads heads:
        # (batch, seq_len, embed_size) -> (batch, num_heads, seq_len, head_dim)
        Q = self.query(x)
        K = self.key(x)
        V = self.value(x)
        Q = Q.view(batch_size, -1, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        K = K.view(batch_size, -1, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        V = V.view(batch_size, -1, self.num_heads, self.head_dim).permute(0, 2, 1, 3)

        # Scaled dot-product attention scores: (batch, num_heads, seq_len, seq_len)
        scores = torch.matmul(Q, K.permute(0, 1, 3, 2)) / (self.head_dim ** 0.5)

        if mask is not None:
            # Mask handling:
            # 1. Adapt masks of different ranks to the score shape
            # 2. Convert the mask to boolean
            # 3. Apply the mask to the attention scores (True -> -1e9)
            if mask.dim() == 2:    # (batch, seq_len) padding mask
                mask = mask.unsqueeze(1).unsqueeze(1)
            elif mask.dim() == 3:  # (batch, seq_len, seq_len) sequence mask
                mask = mask.unsqueeze(1)
            mask = mask.to(torch.bool)
            scores = scores.masked_fill(mask, -1e9)

        attention_weights = F.softmax(scores, dim=-1)
        context = torch.matmul(attention_weights, V)

        # Merge the heads back: (batch, num_heads, seq_len, head_dim) -> (batch, seq_len, embed_size)
        context = context.permute(0, 2, 1, 3).contiguous()
        context = context.view(batch_size, seq_len, self.embed_size)
        output = self.fc_out(context)
        return output, attention_weights
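
# A minimal sanity-check sketch, not part of the original class; it assumes
# PyTorch >= 2.0 for F.scaled_dot_product_attention. It verifies that the
# manual softmax(QK^T / sqrt(d_k)) V computation in forward() matches
# PyTorch's fused implementation. Note the inverted mask convention:
# masked_fill above treats True as "block this position", while
# scaled_dot_product_attention's boolean attn_mask treats True as "attend".
def check_against_sdpa(num_heads=2, head_dim=4, seq_len=4):
    q = torch.rand(1, num_heads, seq_len, head_dim)
    k = torch.rand(1, num_heads, seq_len, head_dim)
    v = torch.rand(1, num_heads, seq_len, head_dim)
    block = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()  # True = masked out

    # Manual computation, mirroring MultiHeadAttention.forward
    scores = torch.matmul(q, k.transpose(-2, -1)) / (head_dim ** 0.5)
    scores = scores.masked_fill(block, -1e9)
    manual = torch.matmul(F.softmax(scores, dim=-1), v)

    # Fused computation with the inverted boolean mask
    fused = F.scaled_dot_product_attention(q, k, v, attn_mask=~block)
    return torch.allclose(manual, fused, atol=1e-5)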

if __name__ == "__main__":
    embed_size = 8
    num_heads = 2
    seq_len = 4
    batch_size = 1

    mha = MultiHeadAttention(embed_size, num_heads)
    x = torch.rand(batch_size, seq_len, embed_size)

    print("===== Test 1: no mask =====")
    out, attn = mha(x)
    print(f"Attention weights shape: {attn.shape}")
    print("Attention matrix of the first head:")
    print(attn[0, 0].detach().numpy().round(3))

    print("\n===== Test 2: padding mask =====")
    # True marks padded positions; here the last two tokens are padding.
    padding_mask = torch.tensor([[False, False, True, True]])
    _, attn_pad = mha(x, mask=padding_mask)
    print("Attention matrix with padding mask:")
    print(attn_pad[0, 0].detach().numpy().round(3))

    print("\n===== Test 3: sequence (causal) mask =====")
    # True marks future positions that must not be attended to.
    seq_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
    seq_mask = seq_mask.unsqueeze(0)
    _, attn_seq = mha(x, mask=seq_mask)
    print("Attention matrix with sequence mask:")
    print(attn_seq[0, 0].detach().numpy().round(3))
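
    # ===== Additional sketches (not in the original tests) =====
    # Padding and causal masks can be combined with logical OR, since both
    # use True to mean "block this position".
    print("\n===== Test 4: combined padding + causal mask (sketch) =====")
    combined = padding_mask.unsqueeze(1) | seq_mask[0]  # broadcasts to (batch, seq_len, seq_len)
    _, attn_comb = mha(x, mask=combined)
    print("Attention matrix with combined mask:")
    print(attn_comb[0, 0].detach().numpy().round(3))

    # Cross-check the manual attention math against PyTorch's fused kernel
    # (requires PyTorch >= 2.0; see check_against_sdpa above).
    print("\nManual attention matches F.scaled_dot_product_attention:",
          check_against_sdpa(num_heads, embed_size // num_heads, seq_len))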